{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Download and extract video"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-09-10T07:17:36.965608Z",
     "start_time": "2018-09-10T07:17:36.806321Z"
    }
   },
   "outputs": [],
   "source": [
    "import cv2\n",
    "from pytube import YouTube\n",
    "from pathlib import Path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-09-10T07:17:37.378352Z",
     "start_time": "2018-09-10T07:17:37.370617Z"
    }
   },
   "outputs": [],
   "source": [
    "save_dir = Path('../data/source/')\n",
    "save_dir.mkdir(exist_ok=True)\n",
    "\n",
    "img_dir = save_dir.joinpath('images')\n",
    "img_dir.mkdir(exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-09-10T07:17:40.030886Z",
     "start_time": "2018-09-10T07:17:37.729714Z"
    }
   },
   "outputs": [],
   "source": [
    "# Bruno Mars - That's What I Like\n",
    "yt = YouTube('https://www.youtube.com/watch?v=PMivT7MJ41M')\n",
    "yt.streams.first().download(save_dir, 'mv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-09-10T07:17:53.162745Z",
     "start_time": "2018-09-10T07:17:40.032189Z"
    }
   },
   "outputs": [],
   "source": [
    "cap = cv2.VideoCapture(str(save_dir.joinpath('mv.mp4')))\n",
    "i = 0\n",
    "while(cap.isOpened()):\n",
    "    flag, frame = cap.read()\n",
    "    if flag == False or i == 1000:\n",
    "        break\n",
    "    cv2.imwrite(str(img_dir.joinpath(f'img_{i:04d}.png')), frame)\n",
    "    i += 1"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Pose estimation (OpenPose)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-09-10T07:17:54.776521Z",
     "start_time": "2018-09-10T07:17:54.553254Z"
    }
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "from tqdm import tqdm\n",
    "\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-09-10T07:17:55.192338Z",
     "start_time": "2018-09-10T07:17:55.173456Z"
    }
   },
   "outputs": [],
   "source": [
    "openpose_dir = Path('../src/pytorch_Realtime_Multi-Person_Pose_Estimation/')\n",
    "\n",
    "import sys\n",
    "sys.path.append(str(openpose_dir))\n",
    "sys.path.append('../src/utils')\n",
    "\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-09-10T07:17:56.164221Z",
     "start_time": "2018-09-10T07:17:55.748391Z"
    }
   },
   "outputs": [],
   "source": [
    "# openpose\n",
    "from network.rtpose_vgg import get_model\n",
    "from evaluate.coco_eval import get_multiplier, get_outputs\n",
    "\n",
    "# utils\n",
    "from openpose_utils import remove_noise, get_pose"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-09-10T07:17:58.478351Z",
     "start_time": "2018-09-10T07:17:56.328212Z"
    }
   },
   "outputs": [],
   "source": [
    "weight_name = openpose_dir.joinpath('network/weight/pose_model.pth')\n",
    "\n",
    "model = get_model('vgg19')     \n",
    "model.load_state_dict(torch.load(weight_name))\n",
    "model = torch.nn.DataParallel(model).cuda()\n",
    "model.float()\n",
    "model.eval()\n",
    "pass"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## check"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-09-10T07:18:00.352813Z",
     "start_time": "2018-09-10T07:18:00.162131Z"
    }
   },
   "outputs": [],
   "source": [
    "img_path = sorted(img_dir.iterdir())[137]\n",
    "img = cv2.imread(str(img_path))\n",
    "shape_dst = np.min(img.shape[:2])\n",
    "# offset\n",
    "oh = (img.shape[0] - shape_dst) // 2\n",
    "ow = (img.shape[1] - shape_dst) // 2\n",
    "\n",
    "img = img[oh:oh+shape_dst, ow:ow+shape_dst]\n",
    "img = cv2.resize(img, (512, 512))\n",
    "          \n",
    "plt.imshow(img[:,:,[2, 1, 0]]) # BGR -> RGB"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-09-10T07:18:08.018250Z",
     "start_time": "2018-09-10T07:18:01.110620Z"
    }
   },
   "outputs": [],
   "source": [
    "multiplier = get_multiplier(img)\n",
    "with torch.no_grad():\n",
    "    paf, heatmap = get_outputs(multiplier, img, model, 'rtpose')\n",
    "    \n",
    "r_heatmap = np.array([remove_noise(ht)\n",
    "                      for ht in heatmap.transpose(2, 0, 1)[:-1]])\\\n",
    "                     .transpose(1, 2, 0)\n",
    "heatmap[:, :, :-1] = r_heatmap\n",
    "param = {'thre1': 0.1, 'thre2': 0.05, 'thre3': 0.5}\n",
    "label = get_pose(param, heatmap, paf)\n",
    "\n",
    "plt.imshow(label)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## make label images for pix2pix"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-09-10T07:18:39.528422Z",
     "start_time": "2018-09-10T07:18:11.930068Z"
    }
   },
   "outputs": [],
   "source": [
    "test_img_dir = save_dir.joinpath('test_img')\n",
    "test_img_dir.mkdir(exist_ok=True)\n",
    "test_label_dir = save_dir.joinpath('test_label')\n",
    "test_label_dir.mkdir(exist_ok=True)\n",
    "\n",
    "for idx in tqdm(range(117, 117+4)):\n",
    "    img_path = img_dir.joinpath(f'img_{idx:04d}.png')\n",
    "    img = cv2.imread(str(img_path))\n",
    "    shape_dst = np.min(img.shape[:2])\n",
    "    oh = (img.shape[0] - shape_dst) // 2\n",
    "    ow = (img.shape[1] - shape_dst) // 2\n",
    "\n",
    "    img = img[oh:oh+shape_dst, ow:ow+shape_dst]\n",
    "    img = cv2.resize(img, (512, 512))\n",
    "    multiplier = get_multiplier(img)\n",
    "    with torch.no_grad():\n",
    "        paf, heatmap = get_outputs(multiplier, img, model, 'rtpose')\n",
    "    r_heatmap = np.array([remove_noise(ht)\n",
    "                      for ht in heatmap.transpose(2, 0, 1)[:-1]])\\\n",
    "                     .transpose(1, 2, 0)\n",
    "    heatmap[:, :, :-1] = r_heatmap\n",
    "    param = {'thre1': 0.1, 'thre2': 0.05, 'thre3': 0.5}\n",
    "    label = get_pose(param, heatmap, paf)\n",
    "    cv2.imwrite(str(test_img_dir.joinpath(f'img_{idx:04d}.png')), img)\n",
    "    cv2.imwrite(str(test_label_dir.joinpath(f'label_{idx:04d}.png')), label)\n",
    "    \n",
    "torch.cuda.empty_cache()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "python3.6",
   "language": "python",
   "name": "python3.6"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  },
  "toc": {
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": "block",
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
